
## IMPORT

from pyspark.context import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.context import SQLContext
from pyspark.sql.functions import *
from pyspark.sql.window import Window

## SPARK SESSION

# Reuse the active SparkContext (or start one with a default SparkConf) and wrap it in a
# SparkSession for the DataFrame API used below.
sc = SparkContext.getOrCreate(
    conf=SparkConf(),
)
spark = SparkSession(sc)

# `spark.sparkContext` is this same `sc`, so configure it directly.
sc.setLogLevel('WARN')
# NOTE(review): nothing visible in this file checkpoints explicitly; presumably this scratch
# directory is kept for GraphFrames-style jobs (per its path) — confirm it is still needed.
sc.setCheckpointDir("splink_sandbox/temp_graphframes/")

#### READ L2 ####

def _read_l2_summary(year):
    """Read one cycle's L2 summary parquet and normalise it for stacking.

    Adds a constant ``cycle`` column equal to *year* and casts ``lat``/``long`` to
    double so every cycle contributes the same column types to the union below.
    """
    return spark.read.parquet(f'data/l2_summ_{year}') \
        .withColumn('cycle', lit(year)) \
        .withColumn('lat', col('lat').cast('double')) \
        .withColumn('long', col('long').cast('double'))


l220 = _read_l2_summary(2020)
l216 = _read_l2_summary(2016)
l212 = _read_l2_summary(2012)

# unionByName matches columns by name; plain union() matches by position and would silently
# shift values into the wrong columns if the yearly files were written with different orders.
l2_all = l220.unionByName(l216).unionByName(l212) \
    .withColumn('cycle', col('cycle').cast('int'))

# Pull the most recent set of covariates for each unique person (component).
# NOTE(review): the window orders by cycle only, so if a component has several rows in its
# latest cycle, which one row_number() picks is nondeterministic; add a tie-breaker if needed.
w = Window.partitionBy("component").orderBy(col("cycle").desc())

# Latest-cycle covariates, one row per component (the helper "row" column is dropped by select).
l2_latest = l2_all \
    .withColumn("row", row_number().over(w)) \
    .filter(col("row") == 1) \
    .select('component', 'education', 'gender', 'age20', 'ethnicity', 'income_est', 'home_val', 'net_worth', 'state_file', 'fips_l2', 'gender_l2', 'lat', 'long') \
    .withColumn('in_l2', lit(1))

# Every (component, cycle) pair present in L2. distinct() keeps repeated L2 rows for the same
# person-year from duplicating rows in this join and in the (component, cycle) joins downstream.
l2_incl = l2_all \
    .select('component', 'cycle') \
    .distinct()

# You get a row for every L2 person-year, with invariant characteristics from the most recent year.
l2_const = l2_incl.join(l2_latest, ['component'], 'left')

#### READ CORELOGIC ####

cl_20 = spark.read.parquet('data/cl_summ_2020')
cl_16 = spark.read.parquet('data/cl_summ_2016')
cl_12 = spark.read.parquet('data/cl_summ_2012')

# unionByName aligns columns by name rather than position, so a difference in column order
# between the yearly files cannot silently move values into the wrong columns.
cl_all = cl_20.unionByName(cl_16).unionByName(cl_12) \
    .withColumn('cycle', col('cycle').cast('int')) \
    .withColumn('in_cl', lit(1)) \
    .withColumn('resid_state', col('resid_state').cast('int'))

# Crosswalk keyed on the integer resid_state code; its `state` column replaces that code
# under the resid_state name.
# NOTE(review): assumes fips_cross has at most one row per resid_state, otherwise the join
# duplicates CoreLogic rows — confirm.
fips_cross = spark.read.parquet('joining/state_fips.parq')

cl_all = cl_all.join(fips_cross, ['resid_state'], 'left') \
    .drop('resid_state') \
    .withColumnRenamed('state', 'resid_state')

#### READ FEC ####

# Read the FEC summary, normalise `cycle` to int, and mark every row as present in FEC.
fec_all = (
    spark.read.parquet('data/conts_summ')
    .withColumn('cycle', col('cycle').cast('int'))
    .withColumn('in_fec', lit(1))
)


#### FLAG clusters (components) that previously had more than one L2 ID ####
# Ids beginning with 'L' are treated as L2 voter IDs; keep each component that has more than
# one such row and tag it with multiple = 1 (result columns: component, multiple).
multi = (
    spark.read.parquet('data/all_components')
    .withColumnRenamed('id', 'LALVOTERID')
    .filter(col('LALVOTERID').startswith('L'))
    .groupBy('component')
    .count()
    .filter(col('count') > 1)
    .select('component', lit(1).alias('multiple'))
)

#### JOIN ALL ####

# Create LONG data: rows keyed by (component, cycle) from CoreLogic, L2 and FEC combined.
# `multi` is a per-component flag, so it is attached with a left join. A full join would also
# emit a row with cycle = NULL for every flagged component absent from all three sources,
# which partitionBy('cycle') below would write into Spark's default null partition.
out = cl_all.join(l2_const, ['component', 'cycle'], 'full') \
    .join(fec_all, ['component', 'cycle'], 'full') \
    .join(multi, ['component'], 'left')

# Sources a row is missing from, and unflagged components, get 0 instead of NULL.
out = out.na.fill(value=0, subset=['in_cl', 'in_fec', 'in_l2', 'multiple'])

out.coalesce(100) \
    .write \
    .partitionBy('cycle') \
    .mode("overwrite") \
    .parquet("data/final_long")
